{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "55a59c39",
   "metadata": {},
   "source": [
    "## This notebook processes the data and generates the results for the experimental section of the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5ffd26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.io import loadmat\n",
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "# These are used for plotting later on\n",
    "from cycler import cycler\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Edit this line to point to the location of the data folder on your system.\n",
    "data_folder_prefix = './data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0329524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth phase maps of the simulated test objects; the same objects\n",
    "# are the reference for every method evaluated below.\n",
    "prefix = data_folder_prefix + '/Simulated/R_Sweep/'\n",
    "test_images = np.angle(np.load(prefix + 'test_images.npy'))\n",
    "\n",
    "# Distance of every pixel from the centre of a 256x256 frame. The cells\n",
    "# below use `Rs > 108` to zero everything outside a circular support.\n",
    "Xs, Ys = np.mgrid[:256, :256]\n",
    "Xs, Ys = Xs - Xs.mean(), Ys - Ys.mean()\n",
    "Rs = np.sqrt(Xs**2 + Ys**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d58df19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_exp_global_phase(im1, im2, size=256):\n",
    "    \"\"\"correct the global phase factor for experimental iteartive reconstruction\n",
    "    \"\"\"\n",
    "    im1 = im1.reshape(1,-1)\n",
    "    im2 = im2.reshape(1,-1)\n",
    "    \n",
    "    mean1 = np.mean(im1)\n",
    "    mean2 = np.mean(im2)\n",
    "    \n",
    "    b = mean2 - mean1\n",
    "    \n",
    "    im = (im1 + b).reshape(size, size)\n",
    "    \n",
    "    return im"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0330202b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first nine ground-truth phase maps on a 3x3 grid.\n",
    "plt.rcParams['figure.figsize'] = [20, 20]\n",
    "fig = plt.figure()\n",
    "fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "for panel in range(1, 10):\n",
    "    ax = fig.add_subplot(3, 3, panel)\n",
    "    # hide ticks and tick labels on both axes\n",
    "    for axis in (ax.axes.xaxis, ax.axes.yaxis):\n",
    "        axis.set_visible(False)\n",
    "    plt.imshow(test_images[panel - 1], cmap='viridis')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aed4769f",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the generative, non-pretrained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37b994d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MS-SSIM metric (higher is better); `ssim` is also used by the cells below.\n",
    "ssim = tf.image.ssim_multiscale\n",
    "\n",
    "# generative not pretrained\n",
    "pcc_mean11 = []  # never filled in this notebook; kept only so the name exists\n",
    "ssim_mean11 = []\n",
    "ssim_std11 = []\n",
    "\n",
    "R = 0.5\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "# Border (px) cropped from each 256x256 image before scoring, leaving 162x162.\n",
    "# NOTE(review): presumably the smallest crop that still fits the default\n",
    "# 5 scales / 11-px filter of tf.image.ssim_multiscale -- confirm before changing.\n",
    "remove = 47\n",
    "\n",
    "for photon_level in [1, 10, 100.0, 1e3]:\n",
    "    pic = []               # reconstructions, one entry per evaluated alpha\n",
    "    ssim_temp = []         # mean MS-SSIM, one entry per evaluated alpha\n",
    "    std_temp = []          # std of MS-SSIM, one entry per evaluated alpha\n",
    "    evaluated_alphas = []  # alphas whose result file was found\n",
    "\n",
    "    # loop over all the alpha values\n",
    "    for alpha in alphas:\n",
    "        # load the generative not pretrained network results; a run may have\n",
    "        # been saved under either date and exactly one file is expected\n",
    "        fname = '-exp-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)\n",
    "        found = [date + fname for date in ['2021-10-05', '2021-10-06'] if os.path.isfile(date + fname)]\n",
    "        if len(found) != 1:\n",
    "            print(\"missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f\" %(R, alpha, photon_level))\n",
    "            # skip only this alpha and keep evaluating the remaining ones\n",
    "            continue\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            # Zero the reconstruction outside the circular support. This is done\n",
    "            # in place, so the figure below shows the masked images too.\n",
    "            gen[Rs > 108] = 0\n",
    "            true = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]\n",
    "            gen = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]\n",
    "            # the dynamic range of the ground-truth crop is passed as max_val\n",
    "            loss_ssim_list.append(ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true)))\n",
    "\n",
    "        pic.append(rec_test_output)\n",
    "        evaluated_alphas.append(alpha)\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    if not ssim_temp:\n",
    "        # Nothing could be scored: store NaN so the lists stay aligned with\n",
    "        # the four photon levels expected by the bar chart.\n",
    "        ssim_mean11.append(np.nan)\n",
    "        ssim_std11.append(np.nan)\n",
    "        continue\n",
    "    best = int(np.argmax(ssim_temp))\n",
    "    print('the best alpha is:', evaluated_alphas[best])\n",
    "    ssim_mean11.append(ssim_temp[best])\n",
    "    ssim_std11.append(std_temp[best])\n",
    "\n",
    "    # show (up to) the first 16 reconstructions for the best alpha\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(min(16, len(pic[best]))):\n",
    "        ax = fig.add_subplot(4, 4, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[best][i], cmap='viridis')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8d97b3f",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the generative, pretrained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf46aade",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generative pretrained (`ssim`, the MS-SSIM metric, is defined in the cell above)\n",
    "ssim_mean22 = []\n",
    "ssim_std22 = []\n",
    "\n",
    "R = 0.5\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "remove = 47  # border (px) cropped from each image before scoring\n",
    "\n",
    "for photon_level in [1, 10, 100.0, 1e3]:\n",
    "    pic = []               # reconstructions, one entry per evaluated alpha\n",
    "    ssim_temp = []         # mean MS-SSIM, one entry per evaluated alpha\n",
    "    std_temp = []          # std of MS-SSIM, one entry per evaluated alpha\n",
    "    evaluated_alphas = []  # alphas whose result file was found\n",
    "\n",
    "    # loop over all the alpha values\n",
    "    for alpha in alphas:\n",
    "        # load the generative pretrained network results; a run may have been\n",
    "        # saved under either date and exactly one file is expected\n",
    "        fname = '-exp-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)\n",
    "        found = [date + fname for date in ['2021-10-05', '2021-10-06'] if os.path.isfile(date + fname)]\n",
    "        if len(found) != 1:\n",
    "            print(\"missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f\" %(R, alpha, photon_level))\n",
    "            # skip only this alpha and keep evaluating the remaining ones\n",
    "            continue\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            # Zero the reconstruction outside the circular support. This is done\n",
    "            # in place, so the figure below shows the masked images too.\n",
    "            gen[Rs > 108] = 0\n",
    "            true = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]\n",
    "            gen = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]\n",
    "            # the dynamic range of the ground-truth crop is passed as max_val\n",
    "            loss_ssim_list.append(ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true)))\n",
    "\n",
    "        pic.append(rec_test_output)\n",
    "        evaluated_alphas.append(alpha)\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    if not ssim_temp:\n",
    "        # Nothing could be scored: store NaN so the lists stay aligned with\n",
    "        # the four photon levels.\n",
    "        ssim_mean22.append(np.nan)\n",
    "        ssim_std22.append(np.nan)\n",
    "        continue\n",
    "    best = int(np.argmax(ssim_temp))\n",
    "    print('the best alpha is:', evaluated_alphas[best])\n",
    "    ssim_mean22.append(ssim_temp[best])\n",
    "    ssim_std22.append(std_temp[best])\n",
    "\n",
    "    # show (up to) the first 16 reconstructions for the best alpha\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(min(16, len(pic[best]))):\n",
    "        ax = fig.add_subplot(4, 4, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[best][i], cmap='viridis')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0f4ab3d",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the non-generative, non-pretrained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a4352d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# not generative not pretrained: one result file per photon level (no alpha sweep)\n",
    "ssim_mean33 = []\n",
    "ssim_std33 = []\n",
    "\n",
    "R = 0.5\n",
    "\n",
    "for photon_level in [1, 10, 100, 1000]:\n",
    "    rec_test_output = loadmat('2021-10-05exp-not-pretrained-R-0.5-peak-' + str(photon_level) + '.mat')['rec_test_output']\n",
    "\n",
    "    # preview the first 16 reconstructions (drawn before the support mask below is applied)\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(1, 17):\n",
    "        ax = fig.add_subplot(4, 4, panel)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[panel - 1], cmap='viridis')\n",
    "\n",
    "    # MS-SSIM of every reconstruction against its ground truth, on the masked\n",
    "    # image with `remove` pixels cropped from each border\n",
    "    remove = 47\n",
    "    scores = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        gen[Rs > 108] = 0\n",
    "        true_crop = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]\n",
    "        gen_crop = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]\n",
    "        dynamic_range = tf.math.reduce_max(true_crop) - tf.math.reduce_min(true_crop)\n",
    "        scores.append(ssim(true_crop, gen_crop, dynamic_range))\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean33.append(np.mean(scores))\n",
    "    ssim_std33.append(np.std(scores))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "193007d6",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the non-generative, pretrained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "295cbc3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# not generative pretrained: one result file per photon level (no alpha sweep)\n",
    "ssim_mean44 = []\n",
    "ssim_std44 = []\n",
    "\n",
    "R = 0.5\n",
    "\n",
    "for photon_level in [1, 10, 100, 1000]:\n",
    "    rec_test_output = loadmat('2021-10-05exp-pretrained-R-0.5-peak-' + str(photon_level) + '.mat')['rec_test_output']\n",
    "\n",
    "    # preview the first 16 reconstructions (drawn before the support mask below is applied)\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(1, 17):\n",
    "        ax = fig.add_subplot(4, 4, panel)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[panel - 1], cmap='viridis')\n",
    "\n",
    "    # MS-SSIM of every reconstruction against its ground truth, on the masked\n",
    "    # image with `remove` pixels cropped from each border\n",
    "    remove = 47\n",
    "    scores = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        gen[Rs > 108] = 0\n",
    "        true_crop = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]\n",
    "        gen_crop = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]\n",
    "        dynamic_range = tf.math.reduce_max(true_crop) - tf.math.reduce_min(true_crop)\n",
    "        scores.append(ssim(true_crop, gen_crop, dynamic_range))\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean44.append(np.mean(scores))\n",
    "    ssim_std44.append(np.std(scores))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bbe5d88",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the End-to-End (E2E) network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639509a1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# End-to-End: one saved output file per photon level\n",
    "ssim_mean55 = []\n",
    "ssim_std55 = []\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    rec_test_output = np.load('exp-End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy')\n",
    "    first_output = rec_test_output[0]\n",
    "    # value range of the first output, printed as a quick sanity check\n",
    "    print(np.max(first_output), np.min(first_output))\n",
    "\n",
    "    # preview the first 16 outputs (drawn before the support mask below is applied)\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(1, 17):\n",
    "        ax = fig.add_subplot(4, 4, panel)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[panel - 1], cmap='viridis')\n",
    "\n",
    "    # MS-SSIM of every output against its ground truth, on the masked image\n",
    "    # with `remove` pixels cropped from each border\n",
    "    remove = 47\n",
    "    scores = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        gen[Rs > 108] = 0\n",
    "        true_crop = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]\n",
    "        gen_crop = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]\n",
    "        dynamic_range = tf.math.reduce_max(true_crop) - tf.math.reduce_min(true_crop)\n",
    "        scores.append(ssim(true_crop, gen_crop, dynamic_range))\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean55.append(np.mean(scores))\n",
    "    ssim_std55.append(np.std(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e12d078d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mean End-to-End MS-SSIM for photon levels 1, 10, 100 and 1000, in that order\n",
    "print(f'{ssim_mean55}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a03343b",
   "metadata": {},
   "source": [
    "## The cell below calculates the pixel offset of the experimental RPI dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ca56b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Find the integer (x, y) shift that best aligns the experimental iterative\n",
    "# reconstruction with the ground truth: every image gets an alignment error for\n",
    "# each candidate shift, and the next cell takes a majority vote over images.\n",
    "remove = 47     # border crop, same value as in the SSIM cells above\n",
    "max_shift = 10  # candidate shifts are -max_shift .. max_shift - 1 per axis (must not exceed remove)\n",
    "n_side = 2 * max_shift\n",
    "\n",
    "iter_test = np.load('test-approx-%d-iter-%d-lr-%0.2f.npy' \n",
    "                             % (1e3, 100, 0.5)).astype(np.float32)[:, remove:-remove, remove:-remove]/np.pi\n",
    "\n",
    "# loss[k, (i + max_shift) * n_side + (j + max_shift)] is the L2 error of image k at\n",
    "# shift (i, j). There is one row per image pair actually compared: a row left\n",
    "# unfilled would make its argmin vote for the first shift.\n",
    "n_pairs = min(len(iter_test), len(test_images))\n",
    "loss = np.full((n_pairs, n_side * n_side), np.inf)\n",
    "\n",
    "for idx, (img1, img2) in tqdm(enumerate(zip(iter_test, test_images)), total=n_pairs):\n",
    "    im1_mean = np.mean(img1)  # does not depend on the shift, so computed once per image\n",
    "    for i in range(-max_shift, max_shift):\n",
    "        for j in range(-max_shift, max_shift):\n",
    "            imgcrop = img2[remove+i:-remove+i, remove+j:-remove+j]\n",
    "            # remove the global phase offset before measuring the error\n",
    "            img1_final = img1 + (np.mean(imgcrop) - im1_mean)\n",
    "            loss[idx, (i + max_shift) * n_side + j + max_shift] = np.sqrt(np.sum(np.square(img1_final - imgcrop)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "163c53b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Majority vote: every image votes for the shift with its smallest alignment\n",
    "# error, and the most common shift is used for all images below.\n",
    "n_side = int(round(np.sqrt(loss.shape[1])))  # 20 when shifts span -10 .. 9\n",
    "max_shift = n_side // 2\n",
    "a = np.argmin(loss, axis=1)\n",
    "counts = np.bincount(a)\n",
    "# invert the flattening  index = (i + max_shift) * n_side + (j + max_shift)\n",
    "row, col = divmod(int(np.argmax(counts)), n_side)\n",
    "shiftx, shifty = row - max_shift, col - max_shift\n",
    "print(\"the offset x and y are:\", shiftx, shifty)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ee2acb4",
   "metadata": {},
   "source": [
    "## The cell below computes the MS-SSIM values at each photon level for the iterative algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44c30243",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# MS-SSIM metric (higher is better)\n",
    "ssim = tf.image.ssim_multiscale\n",
    "\n",
    "iterative_std1 = []\n",
    "iterative_ssim1 = []\n",
    "\n",
    "lr = 0.5\n",
    "R = 0.5\n",
    "remove = 47  # border crop, same value as in the other SSIM cells\n",
    "# Ground-truth value range, used for clipping and as the colour limits. It does\n",
    "# not change inside the loops, so it is computed once here.\n",
    "gt_min, gt_max = np.min(test_images), np.max(test_images)\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    for iters in [100]:\n",
    "        iter_input = np.load('test-approx-%d-iter-%d-lr-%0.2f.npy' \n",
    "                             % (photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "        iterloss_ssim_list = []\n",
    "        gen_list = []\n",
    "        true_list = []\n",
    "        for gen, true in tqdm(zip(iter_input, test_images)):\n",
    "            # crop the ground truth with the (shiftx, shifty) offset found above\n",
    "            # so that it lines up with the cropped reconstruction\n",
    "            true = true[remove+shiftx:-remove+shiftx, remove+shifty:-remove+shifty]\n",
    "            gen = gen[remove:-remove, remove:-remove] / np.pi\n",
    "\n",
    "            # remove the global phase offset (assumes a square crop), then clip\n",
    "            # to the ground-truth range\n",
    "            gen = find_exp_global_phase(gen, true, size=gen.shape[0])\n",
    "            gen = np.clip(gen, gt_min, gt_max)\n",
    "            gen_list.append(gen)\n",
    "            true_list.append(true)\n",
    "\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            iterloss_ssim_list.append(ssim_data)\n",
    "\n",
    "        # show (up to) the first 16 corrected reconstructions\n",
    "        plt.rcParams['figure.figsize'] = [20, 20]\n",
    "        fig = plt.figure()\n",
    "        fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "        for i in range(min(16, len(gen_list))):\n",
    "            ax = fig.add_subplot(4, 4, i + 1)\n",
    "            ax.axes.xaxis.set_visible(False)\n",
    "            ax.axes.yaxis.set_visible(False)\n",
    "            plt.imshow(gen_list[i], cmap='viridis')\n",
    "            plt.clim(gt_min, gt_max)\n",
    "\n",
    "        iterative_ssim1.append(np.mean(iterloss_ssim_list))\n",
    "        iterative_std1.append(np.std(iterloss_ssim_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "683c5374",
   "metadata": {},
   "source": [
    "## The cell below generates the plots used in our paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd53793d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 'seaborn-whitegrid' style was renamed 'seaborn-v0_8-whitegrid' in Matplotlib 3.6\n",
    "# (the old name was deprecated and later removed); use whichever name is available.\n",
    "style_name = 'seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid'\n",
    "plt.style.use(style_name)\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "# The result lists are ordered 1, 10, 100, 1000 photons; labels and data are all\n",
    "# flipped so the bars run from the highest to the lowest photon count.\n",
    "labels = np.flip(['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons'])\n",
    "\n",
    "x = np.arange(len(labels))  # the label locations\n",
    "width = 0.13  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Four bars per photon level, centred on the tick. The non-generative,\n",
    "# not-pretrained results (ssim_mean33 / ssim_std33) are computed above but not plotted.\n",
    "rects1 = ax.bar(x - width * 3/2, np.flip(ssim_mean44), width, yerr=np.flip(ssim_std44), label='Non-Generative', ecolor='black', capsize=5)\n",
    "rects2 = ax.bar(x - width * 1/2, np.flip(ssim_mean11), width, yerr=np.flip(ssim_std11), label='Generative', ecolor='black', capsize=5)\n",
    "rects3 = ax.bar(x + width * 1/2, np.flip(iterative_ssim1), width, yerr=np.flip(iterative_std1), label='Iterative', ecolor='black', capsize=5)\n",
    "rects4 = ax.bar(x + width * 3/2, np.flip(ssim_mean55), width, yerr=np.flip(ssim_std55), label='End-to-End', ecolor='black', capsize=5)\n",
    "\n",
    "# Axis label, tick labels, legend and a black frame around the axes\n",
    "ax.set_ylabel('MS-SSIM', fontsize=35)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, fontsize=25)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "ax.legend(fontsize=39)\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Attach a text label above each bar in *rects*, displaying its height.\n",
    "\n",
    "    Labels are drawn on the ``ax`` created in this cell and show the height\n",
    "    with three decimals. Not called by default, e.g. ``autolabel(rects1)``.\n",
    "    \"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.annotate('{:.3f}'.format(height),\n",
    "                    xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                    xytext=(0, 3),  # 3 points vertical offset\n",
    "                    textcoords=\"offset points\",\n",
    "                    ha='center', va='bottom')\n",
    "\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b284927b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm_to_one(tensor):\n",
    "    tf_max = np.max(tensor)\n",
    "    tf_min = np.min(tensor)\n",
    "    return (tensor - tf_min) / (tf_max - tf_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3959dff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: one test object (index `idx`) reconstructed by every method.\n",
    "# Rows: 1e3, 100, 10 and 1 photons. Columns: iterative (1 iteration),\n",
    "# iterative (100 iterations), End-to-End, non-generative (pretrained),\n",
    "# generative (not pretrained), ground truth.\n",
    "plt.rcParams['figure.figsize'] = [20+4, 20]\n",
    "fig = plt.figure()\n",
    "fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "R = 0.5\n",
    "\n",
    "background = 0  # value written outside the circular support\n",
    "circuit = 108   # radius (px) of the circular support, compared against Rs\n",
    "remove = 47     # border crop used when estimating the global phase offset\n",
    "crop = 20       # border trimmed from every panel before display\n",
    "idx = 3         # test object shown in the figure\n",
    "\n",
    "# (photon level, alpha of the generative not-pretrained result) for each row.\n",
    "# NOTE(review): these alphas look like the best values reported by the sweep\n",
    "# above -- keep them in sync if the sweep changes.\n",
    "rows = [(1e3, 1), (100, 0.25), (10, 0.5), (1, 0.03125)]\n",
    "\n",
    "# Ground truth for this figure, kept under its own name so the `test_images`\n",
    "# used by the SSIM cells above is neither overwritten nor masked in place.\n",
    "prefix = data_folder_prefix + '/Simulated/R_Sweep/'\n",
    "gt_images = np.angle(np.load(prefix + 'test_images-R-0.50.npy'))\n",
    "vmin, vmax = np.min(gt_images), np.max(gt_images)\n",
    "\n",
    "\n",
    "def _mask_support(im):\n",
    "    \"\"\"Return a copy of `im` with every pixel outside the support set to `background`.\"\"\"\n",
    "    return np.where(Rs > circuit, background, im)\n",
    "\n",
    "\n",
    "def _remove_global_phase(rec, gt):\n",
    "    \"\"\"Shift `rec` by its mean phase difference to `gt` (measured on the central\n",
    "    crop) and mask it outside the support.\"\"\"\n",
    "    inner = (slice(remove, -remove), slice(remove, -remove))\n",
    "    return _mask_support(rec + (np.mean(gt[inner]) - np.mean(rec[inner])))\n",
    "\n",
    "\n",
    "def _load_gen_not_pretrained(alpha, photon_level):\n",
    "    \"\"\"Load the generative not-pretrained output, which may be saved under either date.\"\"\"\n",
    "    fname = '-exp-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)\n",
    "    for date in ['2021-10-05', '2021-10-06']:\n",
    "        if os.path.isfile(date + fname):\n",
    "            return loadmat(date + fname)['rec_test_output']\n",
    "    raise FileNotFoundError('no result file found for *' + fname)\n",
    "\n",
    "\n",
    "for row, (photon_level, alpha) in enumerate(rows):\n",
    "    # each result file is loaded once per row rather than once per panel\n",
    "    prox_1 = np.load('test-approx-%d-iter-1-lr-1.00.npy' \n",
    "                             % photon_level).astype(np.float32)/np.pi\n",
    "    prox_100 = np.load('test-approx-%d-iter-100-lr-0.50.npy' \n",
    "                             % photon_level).astype(np.float32)/np.pi\n",
    "    e2e = np.load('exp-End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy')\n",
    "    not_gen = loadmat('2021-10-05exp-pretrained-R-0.5-peak-%d.mat' % photon_level)['rec_test_output']\n",
    "    gen_not_pre = _load_gen_not_pretrained(alpha, photon_level)\n",
    "\n",
    "    panels = [\n",
    "        _remove_global_phase(prox_1[idx], gt_images[idx]),\n",
    "        _remove_global_phase(prox_100[idx], gt_images[idx]),\n",
    "        _mask_support(e2e[idx]),\n",
    "        _mask_support(not_gen[idx]),\n",
    "        _mask_support(gen_not_pre[idx]),\n",
    "        _mask_support(gt_images[idx]),\n",
    "    ]\n",
    "\n",
    "    for col, im in enumerate(panels):\n",
    "        # 5x6 grid with only 4 rows filled; at this figsize the grid cells are square (24/6 = 20/5)\n",
    "        ax = fig.add_subplot(5, 6, row * 6 + col + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        # trim the border, then rotate by 270 degrees (np.rot90 with k=3)\n",
    "        plt.imshow(np.rot90(im[crop:-crop, crop:-crop], 3), cmap='viridis')\n",
    "        plt.clim(vmin, vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21cefbc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
